"""
    DDPM utilities; see the paper by Song et al.
"""

import torch
from torch import Tensor


class DiffusionProcess:
    """
    DDPM Diffusion with a linear beta schedule.

    ``beta(t)`` interpolates linearly from ``beta_min`` at ``t = 0`` to
    ``beta_max`` at ``t = 1``. Samples are perturbed as
    ``x_t = mean_factor(t) * x_0 + std(t) * noise``, where
    ``mean_factor(t) ** 2 + std(t) ** 2 == 1``.

    NOTE(review): the closed forms look like the variance-preserving SDE of
    Song et al. -- confirm against the paper.
    """

    def __init__(self, beta_min: float = 0.1, beta_max: float = 20.0) -> None:
        self.beta_min = beta_min
        self.beta_max = beta_max

    def _log_mean_coeff(self, time: Tensor) -> Tensor:
        """
        Return the log of the marginal mean factor at ``time``.

        This equals ``-0.5 * integral_0^time beta(s) ds`` for the linear
        schedule; both the marginal mean factor and std are derived from it.
        """
        return (
            -0.25 * (time**2) * (self.beta_max - self.beta_min)
            - 0.5 * time * self.beta_min
        )

    @staticmethod
    def _warn_if_not_finite(message: str, values: Tensor) -> None:
        """
        Print ``message`` when ``values`` contains a NaN or an Inf entry.
        """
        if not torch.isfinite(values).all():
            print(message)

    def beta(self, time: Tensor) -> Tensor:
        """
        Calculate beta at given time.

        The result is computed from a detached view of ``time``, so no
        gradient flows back into ``time`` through this method.
        """
        return (self.beta_max - self.beta_min) * time.detach() + self.beta_min

    def marginal_prob_std(self, time: Tensor) -> Tensor:
        """
        Calculate the marginal probability standard deviation.

        Computes ``sqrt(1 - exp(2 * log_mean_coeff(time)))``. ``expm1`` is used
        instead of ``1 - exp(...)`` so that very small times do not round the
        result to exactly zero in low precision.
        """
        return torch.sqrt(-torch.expm1(2.0 * self._log_mean_coeff(time)))

    def marginal_prob_mean_factor(self, time: Tensor) -> Tensor:
        """
        Calculate the marginal probability mean factor.

        This is the scalar that multiplies the clean sample in the perturbed
        sample, ``exp(log_mean_coeff(time))``.
        """
        return torch.exp(self._log_mean_coeff(time))

    def diffusion_coeff(self, time: Tensor) -> Tensor:
        """
        Calculate the diffusion coefficient, ``sqrt(beta(time))``.
        """
        return torch.sqrt(self.beta(time))

    def loss_fn(
        self, model: torch.nn.Module, x_in: Tensor, eps: float = 1e-5
    ) -> Tensor:
        """
        Calculate the loss function for the given model and input.

        One time per sample is drawn uniformly from ``[eps, 1)``, ``x_in`` is
        perturbed with Gaussian noise at that time, and the model output is
        compared against the noise:
        ``mean_batch(sum_other_dims((score * std + noise) ** 2))``.

        Args:
            model: Callable invoked as ``model(time=..., x_batch=...)`` where
                ``time`` has shape ``(batch, 1)`` and ``x_batch`` has the shape
                of ``x_in``; its output must broadcast against ``x_in``.
            x_in: Batch of clean samples; the first dimension is the batch.
            eps: Lower bound for the sampled times, keeping them away from 0.

        Returns:
            Scalar loss tensor.

        Non-finite intermediate values are reported with ``print`` but do not
        interrupt the computation.
        """
        batch_size = x_in.shape[0]
        random_t = torch.rand((batch_size, 1), device=x_in.device) * (1.0 - eps) + eps
        random_x = torch.randn_like(x_in)

        t_flat = random_t.squeeze(1)
        std = self.marginal_prob_std(t_flat)
        mean_factor = self.marginal_prob_mean_factor(t_flat)

        self._warn_if_not_finite(
            f"NaN or Inf in std, t was between: {torch.min(random_t), torch.max(random_t)}",
            std,
        )
        self._warn_if_not_finite(
            f"NaN or Inf in mean factor, t was between: {torch.min(random_t), torch.max(random_t)}",
            mean_factor,
        )

        # Reshape the per-sample scalars to (batch, 1, ..., 1) so they
        # broadcast over every non-batch dimension of x_in.
        std_shape = [-1] + [1] * (x_in.dim() - 1)
        std_expanded = std.view(*std_shape)
        mean_factor_expanded = mean_factor.view(*std_shape)

        perturbed_x = mean_factor_expanded * x_in + random_x * std_expanded
        score = model(time=random_t, x_batch=perturbed_x)
        self._warn_if_not_finite("Score was infinite", score)

        non_batch_dims = tuple(range(1, x_in.dim()))
        residual = score * std_expanded + random_x
        return torch.mean(torch.sum(residual**2, dim=non_batch_dims))
